[Merge stable to main] Llama3.3-70b and 3.1-8b - Fix sampling parameters #36476

Merged
djordje-tt merged 23 commits into main from divanovic/stable_llama3.3_70b
Feb 6, 2026

Conversation

djordje-tt (Contributor) commented Jan 26, 2026

Ticket

#36325

Problem description

This PR fixes a couple of different issues for Llama3.3-70b and Llama3.1-8b:

  • Non-uniform seeding

  • Penalty trap bug

  • Penalty bugs for Llama3.1-8b

  • Batched prefill determinism

  • Divergence between batched and non-batched prefill

  • Missing logprobs support for Llama3.3-70b

  • The same-sampling-parameters issue for Llama3.1-8b

What's changed

  • Bring over the log-probs support for Galaxy (optional log-softmaxed logits output), matching the behavior already validated on stable in TT-Metal, vLLM nightly, and Models CI.
  • Integrate the deterministic seeding flow (host-side RNG + SamplingSeedManager + ttnn.manual_seed usage before ttnn.sampling) so prefill + decode produce deterministic sequences across repeats when seeds are fixed (sketched after this list).
  • Ensure the penalties path matches the shared implementation, fixing the earlier divergence across users.
  • Updated matmul configs to ensure the same behaviour across batched and non-batched prefill, with a couple of additional fixes for divergence.
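
A minimal sketch of the seeding flow described above, assuming a hypothetical per-user RNG interface; the real SamplingSeedManager lives in models/common/sampling/generator.py and may differ:

# Hedged sketch of the deterministic seeding flow; method names and the
# per-user RNG layout are illustrative assumptions, not the PR's code.
import random

class SamplingSeedManager:
    def __init__(self, batch_size, base_seed=0):
        # One host-side RNG per user, so a fixed seed reproduces that user's stream.
        self.rngs = [random.Random(base_seed + user) for user in range(batch_size)]

    def reset_seed(self, user_id, seed):
        self.rngs[user_id].seed(seed)

    def get_new_values(self):
        # Fresh device seed per user for the next sampling step.
        return [rng.randrange(2**31) for rng in self.rngs]

# Per decode step (exact ttnn signatures may differ):
#   seeds = seed_manager.get_new_values()   # host-side RNG draws
#   ttnn.manual_seed(...)                   # seed the device RNG before sampling
#   tokens = ttnn.sampling(logits, ...)     # deterministic given fixed seeds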

Performance numbers on text_demo in t/s/u:

| branch | without penalties | with penalties |
|--------|-------------------|----------------|
| branch | 71.88 t/s/u       | 42.36 t/s/u    |
| main   | 72.05 t/s/u       | -              |

TTFT:
68.5 ms -> 73.9 ms. This regression is expected and comes from disabling use_2d_grid in RMS norm.

Checklist

  • [ ] All post-commit tests: https://github.com/tenstorrent/tt-metal/actions/runs/21355526046

Model tests

  • [x] Galaxy Demo: https://github.com/tenstorrent/tt-metal/actions/runs/21361481284
  • [x] vllm nightly: https://github.com/tenstorrent/tt-metal/actions/runs/21361542050
  • [x] Models CI: https://github.com/tenstorrent/tt-shield/actions/runs/21435406349/job/61728802475

Last pipelines list 6th Feb:

  • [ ] vllm Nightly: https://github.com/tenstorrent/tt-metal/actions/runs/21754091798
  • [ ] Shield CI: https://github.com/tenstorrent/tt-shield/actions/runs/21753926206/job/62758873631
  • [ ] Galaxy Demo: https://github.com/tenstorrent/tt-metal/actions/runs/21754402409

djordje-tt (Contributor Author)

/codeowners ping

@tenstorrent-github-bot

CodeOwners Group Analysis

This PR requires approval from one member of each of the following groups:

Summary: 2 pending groups, 0 approved groups

Group Information:


Note: At least one approval from each group is sufficient.

@tenstorrent-github-bot

Hi Ambrose Ling (@alingTT), Stuti Raizada (@sraizada-tt), Utku Aydonat (@uaydonat), Mark O'Connor (@yieldthought), this PR [Merge stable to main] Llama3.3-70b - Fix sampling parameters by Djordje Ivanovic (@djordje-tt) needs your approval/review to merge this.

ttnn.copy(input_a=tt_logits_list[0], input_b=self.tt_logits_accumulated[user_id])
# On-device sampling for prefill
if do_device_sampling:
    padded_batch = 32
Copilot AI left a comment

Pull request overview

This PR merges sampling and log-probs enhancements from stable to main for Llama3.3-70b on Galaxy, addressing several key issues including non-uniform seeding, penalty bugs, batched prefill determinism, and missing log-probs support.

Changes:

  • Introduces deterministic seeding flow with host-side RNG via SeedManager and ttnn.manual_seed integration for reproducible sampling across prefill and decode
  • Adds log-probs support for Galaxy with on-device log-softmax calculation using distributed reduction operations
  • Updates matmul configurations to ensure consistency between batched and non-batched prefill paths
  • Fixes penalty application bugs (frequency/presence/repetition penalties) with corrected tensor type handling and proper masking
  • Extends prefill path to support on-device sampling with sharded logits accumulation

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 8 comments.

Show a summary per file

| File | Description |
|------|-------------|
| qwen_model_config.py | Updated matmul program config overwrite_per_core_k values for seq_len=128 operations to improve batched prefill consistency |
| model_config.py | Similar matmul config adjustments for Llama model (non-Qwen variant) |
| llama_model.py | Added process_output_prefill_logits method for on-device sampling; updated return signatures to include log_probs |
| llama_mlp.py | Threaded batch_size parameter through forward_prefill for CCL operations |
| llama_decoder.py | Passed batch_size to feed_forward layer for proper CCL buffer selection |
| llama_ccl.py | Added log-probs persistent buffers; updated reduce_scatter/all_gather for batched operations; removed WO from reduce_scatter (now uses WO_AG) |
| llama_attention.py | Updated SDPA program config and WO buffer key for batched prefill; deallocate intermediate output_11SH tensor |
| generator.py (galaxy) | Implemented on-device prefill sampling with logits accumulation; added slot-based parameter scattering; integrated SeedManager |
| text_demo.py | Added test configuration for batch-32 with non-uniform sampling and log-probs |
| demo_qwen_decode.py | Updated to extract and track log-probs during decode |
| outputs_batch_1.json | Expected output reference update (generation variance) |
| utils.py | Refactored LogProbsCalculator for Galaxy (32-device) support with proper all-gather operations and dimension handling |
| test_sampling.py | Added test_log_probs_with_sub_core_grids_on_galaxy for validating log-probs on 32-device mesh |
| tt_sampling.py | Added force_argmax_sampling optimization path; integrated manual_seed with per-user seed tensors |
| tt_penalties.py | Fixed penalty application order and type casting bugs; removed vocab expansion workaround; proper -1 padding handling |
| generator.py (sampling) | Introduced SeedManager class for deterministic per-user RNG; updated trace key to include force_argmax flag |
Comments suppressed due to low confidence (2)

models/common/utils.py:303

  • Several intermediate tensors are not deallocated, which could lead to memory leaks. Consider deallocating: global_idx_tilized_tensor after line 229, chip_ids_tensor after line 249, remainder_tensor after line 239, out after line 276, and log_global_exp_sum after line 276. Also, relevant_logits after line 303.
    def _prepare_relevant_logits(self, logits_tensor: ttnn.Tensor, global_idx_tensor: ttnn.Tensor):
        """
        Prepare global idx tensor with correct values on all devices.
        """
        size_per_device = logits_tensor.shape[-1]

        # convert global_idx_tensor to ttnn.TILE_LAYOUT
        global_idx_tilized_tensor = ttnn.to_layout(global_idx_tensor, ttnn.TILE_LAYOUT, **self.common_args)

        # TODO: Raise an issue on this since for UINT_32 ttnn.div produces incorrect output (all zeros)
        global_idx_tilized_tensor = ttnn.typecast(global_idx_tilized_tensor, ttnn.float32, **self.common_args)

        # Get chip_id for each user based on global_idx values in global_idx_tensor
        chip_ids_tensor = ttnn.div(
            global_idx_tilized_tensor,
            size_per_device,
            round_mode="floor",
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            **self.common_args,
        )

        # Get local index for each user based on global_idx values in global_idx_tensor
        remainder_tensor = ttnn.remainder(
            global_idx_tilized_tensor,
            size_per_device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
            **self.common_args,
        )

        # Convert remainder_tensor to int32
        remainder_tensor = ttnn.typecast(remainder_tensor, ttnn.uint32, **self.common_args)
        # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
        remainder_tensor = ttnn.to_layout(remainder_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
        remainder_tensor = ttnn.reshape(remainder_tensor, (1, 1, 32, 1), **self.common_args)
        remainder_tensor = ttnn.to_layout(remainder_tensor, ttnn.TILE_LAYOUT, **self.common_args)

        # Get logits for each user on each chip based on local index
        selected_logits_tensor = ttnn.gather(logits_tensor, dim=3, index=remainder_tensor, **self.common_args)

        # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
        selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
        selected_logits_tensor = ttnn.reshape(selected_logits_tensor, (1, 1, 1, 32), **self.common_args)
        selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.TILE_LAYOUT, **self.common_args)
        # Compare mask to chip_ids tensor and select correct positions for each user on all chips inplace
        ttnn.eq_(chip_ids_tensor, self.mask, **self.common_args)

        # Multiply selected_logits_tensor with chip_ids_tensor to get expected logits for each user
        selected_logits_tensor = ttnn.multiply(selected_logits_tensor, chip_ids_tensor, **self.common_args)

        # All gather logits across all devices
        selected_logits_tensor = self._perform_all_gather(
            selected_logits_tensor,
            dim=1,
            num_links=1,
            buffer_key="LOGPROBS_LOGITS",
        )

        selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
        selected_logits_tensor = ttnn.reshape(selected_logits_tensor, (1, 1, 8, 32), **self.common_args)
        selected_logits_tensor = ttnn.to_layout(selected_logits_tensor, ttnn.TILE_LAYOUT, **self.common_args)

        # Apply sum over device dimension to get logits for each user on all chips
        selected_logits_tensor = ttnn.sum(selected_logits_tensor, dim=2, keepdim=True, **self.common_args)

        return selected_logits_tensor

    def _calculate_log_probs(self, sampled_logits_tensor: ttnn.Tensor):
        """
        Calculate log-probs for a given logits tensor with formula:
        log-prob(x) = logits(x) - global_max - log(global_exp_sum)
        """
        out = ttnn.subtract(sampled_logits_tensor, self.global_max, **self.common_args)
        log_global_exp_sum = ttnn.log(self.global_exp_sum, **self.common_args)
        # Subtract and put result to self.output_tensor
        ttnn.subtract(out, log_global_exp_sum, output_tensor=self.output_tensor, **self.common_args)

    def calculate_log_probs(
        self,
        logits_tensor: ttnn.Tensor,
        indices_tensor: ttnn.Tensor,
    ):
        """
        Calculate log-probs for a given logits tensor and indices tensor.
        """
        if not self.enable_log_probs:
            return self.output_tensor

        if self.mesh_device.get_num_devices() not in [8, 32]:
            return self.output_tensor

        # Calculating log-probs requires bfloat16 precision for near-stable sum-exp calculation
        if logits_tensor.dtype == ttnn.bfloat8_b:
            logits_tensor = ttnn.typecast(logits_tensor, ttnn.bfloat16, **self.common_args)

        # Compute global max and global sum(exp(logits - global_max)) for each chip
        self._compute_global_stats(logits_tensor)

        # Prepare relevant logits for each user on each chip
        relevant_logits = self._prepare_relevant_logits(logits_tensor, indices_tensor)

        # Calculate log-probs for each user on each chip and stores in self.output_tensor
        self._calculate_log_probs(relevant_logits)

models/demos/llama3_70b_galaxy/tt/llama_mlp.py:323

        if 1024 <= seq_len < 4096:

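For context on the suppressed utils.py comment above: the formula in _calculate_log_probs, log-prob(x) = logits(x) - global_max - log(global_exp_sum), is the standard numerically stable log-softmax. A minimal host-side NumPy check (illustrative only, not part of the PR):

# Illustrative check of the stable log-softmax identity; NumPy stands in for ttnn here.
import numpy as np

logits = np.array([2.0, 1.0, 0.1, -3.0])
global_max = logits.max()
global_exp_sum = np.exp(logits - global_max).sum()

# log-prob(x) = logits(x) - global_max - log(global_exp_sum)
log_probs = logits - global_max - np.log(global_exp_sum)

# Identical to the naive (overflow-prone) log-softmax:
assert np.allclose(log_probs, logits - np.log(np.exp(logits).sum()))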

freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)

freq_term_bf16 = ttnn.typecast(freq_term, ttnn.bfloat16, **op_kwargs)
Copilot AI Jan 26, 2026

The intermediate variable freq_term is created but never deallocated. This could lead to memory leaks in long-running scenarios. Consider adding freq_term.deallocate() after line 50.

Suggested change
- freq_term_bf16 = ttnn.typecast(freq_term, ttnn.bfloat16, **op_kwargs)
+ freq_term_bf16 = ttnn.typecast(freq_term, ttnn.bfloat16, **op_kwargs)
+ freq_term.deallocate()


# presence
presence_term = ttnn.multiply(context.output_mask, context.presence_penalties, **op_kwargs)
presence_term = ttnn.multiply(
Copilot AI Jan 26, 2026

The intermediate variable presence_term is created but never deallocated after creating presence_term_bf16. This could lead to memory leaks. Consider adding presence_term.deallocate() after line 41.

Comment on lines +166 to +200
  local_max_tensor = ttnn.max(logits_tensor, dim=-1, keepdim=True, **self.common_args)

  # All-gather local max to get global max
- gathered_max_tensors = ttnn.all_gather(
+ gathered_max_tensors = self._perform_all_gather(
      local_max_tensor,
-     dim=3,
+     dim=1,
      num_links=1,
-     memory_config=ttnn.DRAM_MEMORY_CONFIG,
-     cluster_axis=None,
-     topology=ttnn.Topology.Linear,
+     buffer_key="LOGPROBS_MAX_REDUCTION",
  )
- self.global_max = ttnn.max(gathered_max_tensors, dim=-1, keepdim=True)
+ # TODO: Convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
+ gathered_max_tensors = ttnn.to_layout(gathered_max_tensors, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
+ gathered_max_tensors = ttnn.reshape(gathered_max_tensors, (1, 1, 8, 32), **self.common_args)
+ gathered_max_tensors = ttnn.to_layout(gathered_max_tensors, ttnn.TILE_LAYOUT, **self.common_args)
+
+ self.global_max = ttnn.max(gathered_max_tensors, dim=2, keepdim=True, **self.common_args)
+
+ global_max_to_subtract = ttnn.to_layout(self.global_max, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
+ global_max_to_subtract = ttnn.reshape(global_max_to_subtract, (1, 1, 32, 1), **self.common_args)
+ global_max_to_subtract = ttnn.to_layout(global_max_to_subtract, ttnn.TILE_LAYOUT, **self.common_args)

  # Calculate stable local sum-exp using subtract of global-max from each local logit
- subtracted_tensor = ttnn.subtract(logits_tensor, self.global_max)
- sum_exp_tensor = ttnn.sum(ttnn.exp(subtracted_tensor), dim=-1, keepdim=True)
+ subtracted_tensor = ttnn.subtract(logits_tensor, global_max_to_subtract, **self.common_args)
+ exp_tensor = ttnn.exp(subtracted_tensor, **self.common_args)
+ sum_exp_tensor = ttnn.sum(exp_tensor, dim=-1, keepdim=True, **self.common_args)

  # All-gather stable local sum-exp to get global sum-exp
- gathered_sum_exp_tensors = ttnn.all_gather(
+ gathered_sum_exp_tensors = self._perform_all_gather(
      sum_exp_tensor,
-     dim=3,
+     dim=1,
      num_links=1,
-     memory_config=ttnn.DRAM_MEMORY_CONFIG,
-     cluster_axis=None,
-     topology=ttnn.Topology.Linear,
+     buffer_key="LOGPROBS_SUM_EXP_REDUCTION",
  )
- self.global_exp_sum = ttnn.sum(gathered_sum_exp_tensors, dim=-1, keepdim=True)
+ gathered_sum_exp_tensors = ttnn.to_layout(gathered_sum_exp_tensors, ttnn.ROW_MAJOR_LAYOUT, **self.common_args)
+ gathered_sum_exp_tensors = ttnn.reshape(gathered_sum_exp_tensors, (1, 1, 8, 32), **self.common_args)
+ gathered_sum_exp_tensors = ttnn.to_layout(gathered_sum_exp_tensors, ttnn.TILE_LAYOUT, **self.common_args)

- # reshape global_max and global_exp_sum to support same output shape as sampling output -> (1, 1, 1, 32)
- # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
- self.global_max = ttnn.to_layout(self.global_max, ttnn.ROW_MAJOR_LAYOUT)
- self.global_max = ttnn.reshape(self.global_max, (1, 1, 1, 32))
- self.global_max = ttnn.to_layout(self.global_max, ttnn.TILE_LAYOUT)
-
- # convert to ROW_MAJOR_LAYOUT due to memory clobbering which affects all ttnn.reshape ops with TILE_LAYOUT
- self.global_exp_sum = ttnn.to_layout(self.global_exp_sum, ttnn.ROW_MAJOR_LAYOUT)
- self.global_exp_sum = ttnn.reshape(self.global_exp_sum, (1, 1, 1, 32))
- self.global_exp_sum = ttnn.to_layout(self.global_exp_sum, ttnn.TILE_LAYOUT)
+ self.global_exp_sum = ttnn.sum(gathered_sum_exp_tensors, dim=2, keepdim=True, **self.common_args)
Copilot AI Jan 26, 2026

Several intermediate tensors are not deallocated, which could lead to memory leaks. Consider deallocating: local_max_tensor after line 173, subtracted_tensor after line 187, exp_tensor after line 188, sum_exp_tensor after line 188, and global_max_to_subtract after line 186.

Comment on lines +353 to +354
if sampling_params.top_k[i] < 1:
    sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)
Copilot AI Jan 26, 2026

Duplicate check for top_k < 1 at lines 347-348 and 353-354. The second check (lines 353-354) is redundant and should be removed.

Suggested change
- if sampling_params.top_k[i] < 1:
-     sampling_params.top_k[i] = 32  # k<1 means no restriction so set it to max k (32)

@@ -630,28 +743,40 @@ def _decode_easy_trace_text(

return trace_tok_rm
Copilot AI Jan 26, 2026

When enable_split_sampling is True and return_logits is False, the function returns the result from self.model.sampling.sample() which should return a tuple (tt_tokens, tt_log_probs) according to line 645-650 of llama_model.py. However, when enable_split_sampling is False or return_logits is True, the function returns trace_tok_rm which is the raw output. This creates inconsistent return types. The caller at line 565 expects a tuple (tt_tok, tt_log_probs). Consider ensuring consistent return types.

Suggested change
- return trace_tok_rm
+ # For consistency, always return (tt_out, tt_log_probs) where log_probs may be None.
+ return trace_tok_rm, None

# frequency
output_counts_bf16 = ttnn.typecast(context.output_counts, ttnn.bfloat16, **op_kwargs)

freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
Copilot AI Jan 26, 2026

The intermediate variable output_counts_bf16 is created but never deallocated. This could lead to memory leaks in long-running scenarios. Consider adding output_counts_bf16.deallocate() after line 48.

Suggested change
- freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
+ freq_term = ttnn.multiply(output_counts_bf16, context.frequency_penalties, **op_kwargs)
+ output_counts_bf16.deallocate()

Comment on lines +302 to +303
padded_batch = 32

Copilot AI Jan 26, 2026

Variable padded_batch is not used.

Suggested change
- padded_batch = 32

djordje-tt force-pushed the divanovic/stable_llama3.3_70b branch from b4abed2 to 01544b5 on January 26, 2026 12:01
sraizada-tt requested a review from tchedaTT on January 26, 2026 12:08
yieldthought (Contributor)

Codex review:

The patch introduces at least one runtime break (missing SamplingGenerator.reset_seed) and a logic regression in batched prefill logits. It also changes a critical API keyword in log-prob calculation that can corrupt results when log_probs are enabled.

Full review comments:

  • [P1] Preserve SamplingGenerator.reset_seed API — models/common/sampling/generator.py:64
    SamplingGenerator now only initializes a SeedManager, but the public reset_seed method was removed; existing call sites such as models/tt_transformers/tt/generator.py still invoke sampling_module.reset_seed(...) when sampling on device. That call will raise AttributeError at runtime for any decode that uses sampling_params. Either keep a wrapper method on SamplingGenerator or update all call sites to use seed_manager.reset_seed()/get_new_values(). A minimal sketch of the wrapper option follows after this list.

  • [P2] Avoid batched prefill when returning logits — models/demos/llama3_70b_galaxy/tt/generator.py:192
    The new use_batched_prefill condition no longer checks return_logits/tt_out_logits_all_users, so batched prefill now runs even when callers request host logits (sampling_params=None). In that path only tt_out_logits_all_users[id] (id=0 in batched mode) is populated, leaving other users' logits as zeros. This yields incorrect logits for batch>=16/seq_len=128. Reintroduce the guard or fill logits for every user in batched mode.

  • [P2] Use rounding_mode for log-probs chip-id div — models/common/utils.py:215
    In LogProbsCalculator._prepare_relevant_logits, ttnn.div is called with round_mode="floor". The TTSampling/TTNN API uses rounding_mode; the new keyword is ignored or errors, so chip IDs can be computed with floating division and the remainder/log_probs become incorrect when enable_log_probs is true. Use rounding_mode="floor" here to keep integer chip IDs.
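
For the P1 item, a minimal sketch of the backward-compatible wrapper option; the SeedManager stub and constructor details here are assumptions, not the PR's actual code:

# Hedged sketch: preserve the old public reset_seed API by forwarding to the
# new SeedManager so existing call sites keep working unchanged.
class SeedManager:
    def reset_seed(self, user_id, seed):
        ...  # stub standing in for the real implementation in models/common/sampling/generator.py

class SamplingGenerator:
    def __init__(self):
        self.seed_manager = SeedManager()

    def reset_seed(self, *args, **kwargs):
        # Backward-compatible wrapper around the new seeding flow, so
        # models/tt_transformers/tt/generator.py does not hit AttributeError.
        return self.seed_manager.reset_seed(*args, **kwargs)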

djordje-tt force-pushed the divanovic/stable_llama3.3_70b branch from 7a1b14e to e44be11 on February 4, 2026 12:48
rdraskicTT requested a review from a team as a code owner February 4, 2026 16:36
djordje-tt force-pushed the divanovic/stable_llama3.3_70b branch from 9602978 to bdfb675 on February 6, 2026 14:27

x = self.lm_head(x)

if mode == "prefill":
Contributor

Why is it safe to remove this?

if self.model_config["LM_HEAD_INPUT_MEMCFG"].is_sharded():
    logits = ttnn.interleaved_to_sharded(logits, self.model_config["LM_HEAD_INPUT_MEMCFG"])
logits = self.lm_head(logits)
logits = ttnn.to_layout(logits, layout=ttnn.ROW_MAJOR_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
Contributor

Is it safe to remove this as well? Is logits already in DRAM? Does downstream usage of this function continue to expect ROW_MAJOR? Please double-check nothing breaks here.

Contributor Author

Pipelines like the Llama3.1-8b demo, vLLM nightly for Llama3.1-8b, and Models CI for Llama3.1-8b are passing. I can double-check, but I am pretty confident it's not needed anymore!

djordje-tt added this pull request to the merge queue Feb 6, 2026
Merged via the queue into main with commit d406ab4 Feb 6, 2026
136 of 148 checks passed
djordje-tt deleted the divanovic/stable_llama3.3_70b branch February 6, 2026 22:36
djordje-tt restored the divanovic/stable_llama3.3_70b branch February 8, 2026 11:48
djordje-tt added a commit that referenced this pull request Feb 9, 2026
…ers (#36476)

djordje-tt added a commit that referenced this pull request Feb 9, 2026
…ers (#36476)

dpopovTT pushed a commit that referenced this pull request Feb 9, 2026
…ers (#36476)

adrian-pascual-bernal pushed a commit that referenced this pull request Feb 10, 2026
…ers (#36476)

ssundaramTT pushed a commit that referenced this pull request Feb 10, 2026
…ers (#36476)

bgoelTT pushed a commit that referenced this pull request Feb 10, 2026
…ers (#36476)
